#!/usr/bin/env python3

# Copyright (c) 2024-2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
import tempfile
import queue
import threading

# find direct_io in the same directory as this program
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import direct_io

def copy_worker(fd_src, fd_dst, buffer_size, workpile, thread_buffer, fs_block_size, thread_id):
    """Copy chunks taken from the workpile until the workpile is empty.

    Each work item is an ``(offset, size)`` tuple. The chunk is read from
    fd_src into thread_buffer with direct_io.pread and written to fd_dst at
    the same offset with direct_io.pwrite.

    Args:
        fd_src: source file descriptor, passed to direct_io.pread.
        fd_dst: destination file descriptor, passed to direct_io.pwrite.
        buffer_size: not used by this function; kept for interface compatibility.
        workpile: queue of ``(offset, size)`` tuples, drained with get_nowait().
        thread_buffer: buffer handed to direct_io for both the read and the write.
        fs_block_size: forwarded unchanged to direct_io.pread and direct_io.pwrite.
        thread_id: forwarded unchanged to direct_io.pread.

    Raises:
        OSError: if a read or write transfers a different number of bytes
            than requested.
    """
    while True:
        try:
            # Get the next chunk to copy.
            offset, size = workpile.get_nowait()
        except queue.Empty:
            # Workpile is empty, so exit.  Only queue.Empty means "no more
            # work"; any other error is a real failure and must propagate.
            return

        # Original author's note: pread_direct and pwrite_direct have their own
        # retry loop so we don't need one here (direct_io is not visible from
        # this file -- confirm there).
        bytes_read = direct_io.pread(fd_src, thread_buffer, size, offset, fs_block_size, thread_id)
        # Explicit checks instead of assert: asserts vanish under "python -O",
        # which would let a short transfer go unnoticed.
        if bytes_read != size:
            raise OSError(
                f"pread returned different size than requested "
                f"({bytes_read} != {size} at offset {offset})")
        bytes_written = direct_io.pwrite(fd_dst, thread_buffer, bytes_read, offset, fs_block_size)
        if bytes_written != size:
            raise OSError(
                f"pwrite returned different size than requested "
                f"({bytes_written} != {size} at offset {offset})")
        workpile.task_done()
            

def fastcp(src_path, dst_path, num_threads, buffer_size):
    """Copy a file using multiple threads with a workpile algorithm.

    Both files are opened with O_DIRECT. The source is split into
    buffer_size-sized chunks that are queued on a workpile; num_threads
    threads running copy_worker drain it, each with its own aligned buffer.
    Once every worker has finished successfully the destination is truncated
    to the source size.

    Args:
        src_path: path of the file to read.
        dst_path: path of the file to create or overwrite.
        num_threads: number of worker threads (at least 1).
        buffer_size: chunk size in bytes, also the size of each thread buffer
            (at least 1).

    Raises:
        ValueError: if num_threads or buffer_size is less than 1.
        RuntimeError: if a file cannot be opened, if source and destination
            are the same file, or if any step of the copy (including any
            worker thread) fails.
    """
    # With no workers nothing would be copied, yet the destination would still
    # be extended to the source size below -- reject that up front.
    if num_threads < 1:
        raise ValueError(f"num_threads must be at least 1, got {num_threads}")
    if buffer_size < 1:
        raise ValueError(f"buffer_size must be at least 1, got {buffer_size}")

    fd_src, fd_dst = None, None
    print("opening direct")
    try:
        # Open files as file descriptors with O_DIRECT.
        try:
            fd_src = os.open(src_path, os.O_RDONLY | os.O_DIRECT)
            src_stat = os.fstat(fd_src)
        except OSError as e:
            raise RuntimeError(f"Failed to open source file '{src_path}': {e}") from e

        # Opening the destination with O_TRUNC would wipe the source if both
        # paths name the same file (directly or through a hard link/symlink).
        try:
            dst_stat = os.stat(dst_path)
        except OSError:
            dst_stat = None  # destination does not exist yet (or is not reachable)
        if dst_stat is not None and os.path.samestat(src_stat, dst_stat):
            raise RuntimeError(f"'{src_path}' and '{dst_path}' are the same file")

        try:
            # Give a newly created file the source's permission bits (still
            # subject to the umask) instead of os.open's default of 0o777.
            fd_dst = os.open(dst_path,
                             os.O_WRONLY | os.O_CREAT | os.O_TRUNC | os.O_DIRECT,
                             src_stat.st_mode & 0o777)
        except OSError as e:
            raise RuntimeError(f"Failed to open destination file '{dst_path}': {e}") from e

        try:
            # Get the size of the source file.
            file_size = src_stat.st_size
            print("file size is", file_size)

            # Get file system block sizes.
            fs_block_size = os.fstatvfs(fd_src).f_bsize
            dst_block_size = os.fstatvfs(fd_dst).f_bsize

            # get each thread an aligned buffer
            alignment = 2*1024*1024 # 2M is the size of Linux Huge Pages, so this should always be enough
            if alignment < max(fs_block_size, dst_block_size, 512):
                raise ValueError("alignment smaller than fs block size")

            # raw_buf is not used directly; presumably it backs thread_buffers
            # and must stay referenced while the workers run -- TODO confirm
            # against direct_io.allocate_aligned_buffers.
            thread_buffers, raw_buf = direct_io.allocate_aligned_buffers(buffer_size, alignment, num_threads)

            # Create the workpile with all file chunks.
            chunk_size = buffer_size
            workpile = queue.Queue()
            for offset in range(0, file_size, chunk_size):
                workpile.put((offset, min(chunk_size, file_size - offset)))

            # An exception raised inside a thread never reaches the code that
            # joins it, so each worker records its failure here instead.
            failures = []

            def run_worker(index):
                try:
                    # NOTE(review): only the source block size is handed to the
                    # worker, which also uses it for writes; dst_block_size is
                    # used for the alignment check only.
                    copy_worker(fd_src, fd_dst, buffer_size, workpile,
                                thread_buffers[index], fs_block_size, index)
                except Exception as exc:
                    failures.append(exc)

            # Start worker threads.
            started = []
            try:
                for i in range(num_threads):
                    thread = threading.Thread(target=run_worker, args=(i,))
                    thread.start()
                    started.append(thread)
            finally:
                # Wait for all threads to finish, even if starting one failed,
                # so no worker is still using the descriptors when they close.
                for thread in started:
                    thread.join()

            if failures:
                # Surface the first worker error instead of reporting success.
                raise failures[0]

            # Set the size of the destination file.
            os.ftruncate(fd_dst, file_size)
        except Exception as e:
            raise RuntimeError(f"Error during file copy operation: {e}") from e
    finally:
        if fd_src is not None:
            os.close(fd_src)
        if fd_dst is not None:
            os.close(fd_dst)

def parse_and_validate_args():
    """Parse the command line (sys.argv) and validate sources and destination.

    Returns:
        The argparse namespace, extended with two derived attributes:
        ``sources`` (list of source paths) and ``destination`` (the target
        file or directory).

    Exits the process (SystemExit) when no file operand is given, when only
    one operand is given without -t, when --num-threads or --buffer-size is
    not positive, when a source does not exist, when a source is a directory
    and -r was not given, when several sources (or -t) target something that
    is not a directory, or when a directory would overwrite an existing file.
    """
    prog_name = os.path.basename(sys.argv[0])  # Extracts script name

    parser = argparse.ArgumentParser(
        description="Copy SOURCE to DEST, or multiple SOURCE(s) to DIRECTORY.",
        usage=f"""{prog_name} [OPTION]... SOURCE DEST
       {prog_name} [OPTION]... SOURCE... DIRECTORY
       {prog_name} [OPTION]... -t DIRECTORY SOURCE..."""
    )

    parser.add_argument("-t", "--target-directory", metavar="DIRECTORY",
                        help="copy all SOURCE arguments into DIRECTORY")
    parser.add_argument("-r", "--recursive", action="store_true",
                        help="copy directories recursively")
    parser.add_argument("-f", "--force", action="store_true",
                        help="overwrite existing destination files")
    # nargs="*" so that a missing operand is reported by the check below
    # rather than by intercepting argparse's SystemExit, which also fired for
    # --help and for every unrelated option error.
    parser.add_argument("files", nargs="*",
                        help="source file(s) and destination (or just sources if -t is used)")
    parser.add_argument("-n", "--num-threads", type=int, default=16,
                        help="Number of threads (default: 16)")
    parser.add_argument("-b", "--buffer-size", type=int, default=256 * 1024 * 1024,
                        help="Buffer size in bytes (default: 256MiB, must be a multiple of 2MiB)")

    args = parser.parse_args()

    if not args.files:
        sys.exit(f"{prog_name}: missing file operand")

    # A non-positive thread count or buffer size cannot copy anything.
    if args.num_threads < 1:
        parser.error(f"--num-threads must be at least 1 (got {args.num_threads})")
    if args.buffer_size < 1:
        parser.error(f"--buffer-size must be at least 1 (got {args.buffer_size})")

    # Determine sources and destination
    explicit_target = args.target_directory is not None
    if explicit_target:
        args.sources = args.files
        args.destination = args.target_directory
    else:
        args.sources = args.files[:-1]
        args.destination = args.files[-1]

    # Ensure at least one source is provided
    if not args.sources:
        # the user provided only one file name, which we've recorded as destination, but user
        # usually meant that to be the source file name, and forgot to give a target
        # (parser.error already prefixes the program name, so it is not repeated here)
        parser.error(f"missing destination file operand after '{args.destination}'")

    # Validate sources
    for src in args.sources:
        if not os.path.exists(src):
            sys.exit(f"{prog_name}: cannot stat '{src}': No such file or directory")
        if os.path.isdir(src) and not args.recursive:
            sys.exit(f"{prog_name}: -r not specified; omitting directory '{src}'")

    # Validate destination
    if explicit_target or len(args.sources) > 1:
        # Several sources (or -t) require an existing directory as target.
        if not os.path.isdir(args.destination):
            sys.exit(f"{prog_name}: target '{args.destination}' is not a directory")
    elif os.path.isfile(args.destination) and os.path.isdir(args.sources[0]):
        # A directory cannot replace an existing regular file.
        sys.exit(f"{prog_name}: cannot overwrite '{args.destination}' with directory '{args.sources[0]}'")

    return args

def list_relative_files(root):
    """Return all file paths under 'root' as relative paths, sorted alphabetically.

    Symlinked directories are followed (os.walk(followlinks=True)). To keep
    symlink cycles from being walked over and over, a subdirectory that
    resolves to the directory being listed or to one of its ancestors in the
    walk (same device and inode) is not descended into. Symlinks to
    directories that are not ancestors are still followed, so their files are
    listed under the symlink's path as well.

    Returns an empty list when 'root' cannot be stat'ed.
    """
    top = os.fspath(root)
    try:
        top_stat = os.stat(top)
    except OSError:
        return []

    # Maps a directory path that os.walk is going to visit to the identities
    # (st_dev, st_ino) of that directory and all of its ancestors in the walk.
    pending_chains = {top: frozenset({(top_stat.st_dev, top_stat.st_ino)})}

    file_list = []
    for dirpath, dirnames, filenames in os.walk(top, followlinks=True):
        chain = pending_chains.pop(dirpath, frozenset())

        # Editing dirnames in place tells os.walk which subdirectories to
        # descend into (it walks top-down by default).
        kept = []
        for name in dirnames:
            child = os.path.join(dirpath, name)
            try:
                child_stat = os.stat(child)
            except OSError:
                continue  # vanished or unreadable; nothing to list beneath it
            identity = (child_stat.st_dev, child_stat.st_ino)
            if identity in chain:
                continue  # points back at an ancestor: descending would loop
            pending_chains[child] = chain | {identity}
            kept.append(name)
        dirnames[:] = kept

        for fname in filenames:
            full_path = os.path.join(dirpath, fname)
            file_list.append(os.path.relpath(full_path, top))
    return sorted(file_list)

def plan_copy_operations(args):
    """Return list of (src_abs, dst_abs, size_bytes) file tuples to copy.

    Reads ``args.sources`` and ``args.destination``:

    * Destination is not an existing directory and the single source is a
      file: one job copying that file to the destination path.
    * Destination is not an existing directory and the single source is a
      directory: every file beneath the source is mapped to the same relative
      path beneath the destination.
    * Destination is an existing directory: each source file goes to
      ``destination/<basename>``; each source directory is mapped file by
      file beneath ``destination/<basename>/``.

    Only files produce jobs, so a directory without files contributes nothing.
    Sizes come from os.path.getsize at planning time.
    """
    file_jobs = []
    dst_root = os.path.abspath(args.destination)

    if not os.path.isdir(dst_root):
        # Destination does not exist as a directory: only the first source is used.
        src_abs = os.path.abspath(args.sources[0])
        if os.path.isdir(src_abs):
            # Directory copied to a new name: its contents land directly under
            # dst_root.  Treating the directory as one file would hand a
            # directory path to the file copier.
            for relpath in list_relative_files(src_abs):
                full_src = os.path.join(src_abs, relpath)
                full_dst = os.path.join(dst_root, relpath)
                file_jobs.append((full_src, full_dst, os.path.getsize(full_src)))
        else:
            # single file copy
            file_jobs.append((src_abs, dst_root, os.path.getsize(src_abs)))
        return file_jobs

    for src in args.sources:
        src_abs = os.path.abspath(src)
        # Strip trailing slashes so "dir/" still yields "dir" as the name.
        base = os.path.basename(src.rstrip("/"))
        if os.path.isdir(src):
            for relpath in list_relative_files(src):
                full_src = os.path.join(src_abs, relpath)
                full_dst = os.path.join(dst_root, base, relpath)
                file_jobs.append((full_src, full_dst, os.path.getsize(full_src)))
        else:
            dst_path = os.path.join(dst_root, base)
            file_jobs.append((src_abs, dst_path, os.path.getsize(src_abs)))

    return file_jobs

if __name__ == "__main__":
    # Script entry point: parse the command line, plan every file copy, refuse
    # to clobber existing files unless --force was given, then copy one by one.
    two_mib = 2 * 1024 * 1024
    opts = parse_and_validate_args()

    # round buffer size up to next multiple of 2 MiB
    if opts.buffer_size % two_mib:
        opts.buffer_size = direct_io.round_up(opts.buffer_size, two_mib)
        print(f"rounding buffer size up to {opts.buffer_size} ({opts.buffer_size/(1024*1024)} MiB)")

    print(f"Sources: {opts.sources}")
    print(f"Destination: {opts.destination}")

    jobs = plan_copy_operations(opts)

    if not opts.force:
        # Stop at the first planned destination that already exists.
        clash = next((target for _, target, _ in jobs if os.path.exists(target)), None)
        if clash is not None:
            sys.exit(f"{os.path.basename(sys.argv[0])}: will not overwrite existing file '{clash}' without --force")

    for source, target, nbytes in jobs:
        print(f"Copy {source} -> {target} ({nbytes} bytes)")
        # Make sure the directory that will hold the destination file exists.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        fastcp(source, target, opts.num_threads, opts.buffer_size)
